[NPU]: Add NPU support for the embedding #1028
Conversation
Hi @Tcc0403, could you please help me review my code?
Tcc0403 left a comment
It seems the current implementation is quite inefficient. I've left some comments about possible issues it might have.
def get_optimal_block_size(total_elements, is_backward: bool):
what does is_backward do?
Sorry, at first I intended to distinguish the forward and backward directions. Later, I realized their logic was quite similar and I forgot to delete it.
@triton.jit
def embedding_forward_kernel(
    embeddings_ptr,
    indices_ptr,
    output_ptr,
    total_elements,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
I think the original implementation with 2 block sizes for the tile shape is more readable and more efficient.
A persistent grid loop is fine, but the way this kernel loads the embedding seems to be uncoalesced at some point.
For instance, some dim_idx values will not be consecutive if BLOCK_SIZE is not a multiple of embedding_dim. That makes the second tl.load access different rows within a warp, and the same applies to the last store.
Creating these offsets with a 2D block size is more readable and efficient, since it avoids the uncoalesced access mentioned above.
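For illustration, here is a minimal sketch (not taken from the PR) of the 2D-offset pattern described above; the names BLOCK_N and BLOCK_D, and the flat (n_indices, embedding_dim) row-major layout, are assumptions:

```python
import triton
import triton.language as tl


@triton.jit
def embedding_forward_2d_kernel(
    embeddings_ptr,  # (num_embeddings, embedding_dim), row-major
    indices_ptr,     # (n_indices,) token ids
    output_ptr,      # (n_indices, embedding_dim)
    n_indices,
    embedding_dim: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tokens per tile
    BLOCK_D: tl.constexpr,  # embedding-dim columns per tile
):
    pid_n = tl.program_id(0)
    pid_d = tl.program_id(1)

    # Row offsets walk over the indices; column offsets stay inside one row,
    # so each warp reads consecutive addresses along embedding_dim (coalesced).
    row_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    col_offsets = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    row_mask = row_offsets < n_indices
    col_mask = col_offsets < embedding_dim
    mask_2d = row_mask[:, None] & col_mask[None, :]

    token_ids = tl.load(indices_ptr + row_offsets, mask=row_mask, other=0)

    # 2D gather: a (BLOCK_N, BLOCK_D) tile, contiguous along the last dimension.
    emb_ptrs = embeddings_ptr + token_ids[:, None] * embedding_dim + col_offsets[None, :]
    tile = tl.load(emb_ptrs, mask=mask_2d, other=0.0)

    out_ptrs = output_ptr + row_offsets[:, None] * embedding_dim + col_offsets[None, :]
    tl.store(out_ptrs, tile, mask=mask_2d)
```

Launching this over a (cdiv(n_indices, BLOCK_N), cdiv(embedding_dim, BLOCK_D)) grid keeps every lane of a warp inside a single embedding row.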
I have changed it to a 2D block. After testing, it indeed shows much better performance. The issues mentioned below have also been fixed. Could you please review it again?
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,)
)
dtype_size should be embedding.dtype?
Modified
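For illustration only (the full signature of compute_default_tiling_strategy is not shown in this thread, and the tensor name embeddings is an assumption), the change presumably amounts to deriving the element size in bytes from the weight tensor instead of hard-coding 4:

```python
# Hypothetical call site; `embeddings` is assumed to be the weight tensor.
tile_shapes = compute_default_tiling_strategy(
    safety_margin=0.9,
    dtype_size=embeddings.element_size(),  # bytes per element, e.g. 2 for fp16/bf16
    memory_multiplier=multiplier,
    shapes=((total_elements,),),
    tiling_dims=(0,),
)
```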
        block_size = tile_shapes[0][0]
        return block_size
    else:
        return triton.next_power_of_2(total_elements)
I think the fallback value should be something workable; triton.next_power_of_2(total_elements) is too large.
Modified
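As a hypothetical sketch of the point above (the cap of 1024 is an assumption, not a value from the PR), the fallback can be bounded so it stays workable for large inputs:

```python
import triton

# Assumed upper bound for the fallback path; tune per hardware.
MAX_FALLBACK_BLOCK_SIZE = 1024


def fallback_block_size(total_elements: int) -> int:
    # Bounded, unlike triton.next_power_of_2(total_elements), which can be huge.
    return min(triton.next_power_of_2(total_elements), MAX_FALLBACK_BLOCK_SIZE)
```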
        embeddings_ptr + embedding_offsets,
        mask=final_mask,
        other=0.0,
    ).to(tl.float32)
any consideration why we need to upcast it?
Modified
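A sketch of the load without the upcast, assuming the forward pass is a pure gather and can stay in the weight's native dtype; this is a kernel-body fragment, not complete code, and output_offsets is an assumed name:

```python
# Load in the embedding's native dtype (e.g. fp16/bf16) and store it unchanged;
# a round trip through tl.float32 is unnecessary for a plain gather.
embeddings = tl.load(
    embeddings_ptr + embedding_offsets,
    mask=final_mask,
    other=0.0,
)
tl.store(output_ptr + output_offsets, embeddings, mask=final_mask)
```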
Could you attach the benchmark results for reference?

Currently, compared to the previous version, performance has improved by 4 to 5 times. However, there is still a significant gap compared to HuggingFace. I also tried the original GPU code (only fixing the UB issue), and the performance was nearly the same (results are shown below).
Tcc0403 left a comment
I'm fine with merging this PR since it's an experimental operator and isn’t used in any patching path. That said, we should probably open a performance issue for this kernel and track it for future improvements.
You're right. In fact, we do have plans to improve the performance. For now, we first need to support these operators on the NPU, and then explore ways to optimize performance as much as possible.
Could you open an issue with benchmarking results so we can track this performance problem and allow future contributors to work on it?

Sure! #1036

Thank you!


Summary
Add NPU support for the embedding.
Testing Done
I tested the embedding with the following methods and all cases passed:
- `python benchmark/scripts/benchmark_embedding.py`
- `pytest -v test/transformers/test_embedding.py`
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence